# Install the patched whoosh-reloaded wheel and rewrite whoosh_utils to reference it instead of the old Whoosh 2.7.4 wheel
!pip install /kaggle/input/uspto-whoosh-reloaded-2-7-5-patched/Whoosh_Reloaded-2.7.5-py2.py3-none-any.whl
!sed 's:/kaggle/input/whoosh-wheel-2-7-4/Whoosh-2.7.4-py2.py3-none-any.whl:whoosh-reloaded==2.7.5:g' /kaggle/usr/lib/whoosh_utils/whoosh_utils.py > whoosh_utils.py
import os
os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch"
import keras_nlp
import keras
import numpy as np
import pandas as pd
from tqdm import tqdm
import gc
import re
import whoosh_utils
import whoosh
print("Keras:", keras.__version__)
print("KerasNLP:", keras_nlp.__version__)
class CFG:
    seed = 42
    dataset_path = "/kaggle/input/uspto-explainable-ai"
    preset = "gemma_1.1_instruct_2b_en"  # name of the pretrained Gemma preset
    input_length = 1024  # max size of the input sequence fed to the model
    output_length = 1200  # max size of the output sequence
    num_neighbors = 2  # how many neighbour patents to consider
keras.utils.set_random_seed(CFG.seed)
# Read test.csv and keep only publication_number plus the first `num_neighbors` neighbour columns
test_df = pd.read_csv(f"{CFG.dataset_path}/test.csv")
test_df = test_df.iloc[:, :CFG.num_neighbors+1]
target_cols = list(test_df.columns[1:])
# Merge metadata of the patents
meta_df = pd.read_parquet(f"{CFG.dataset_path}/patent_metadata.parquet")
test_df = test_df.merge(meta_df, on="publication_number", how="left")
# Merge Title and Abstract of the patents
patent_df = pd.read_parquet("/kaggle/input/uspto-all-patents-after-1975/all_patents.parquet")
test_df = test_df.merge(patent_df, on="publication_number", how="left")  # left join so no test rows are dropped
# Fill NaN values
test_df["title"] = test_df["title"].fillna("")
test_df["abstract"] = test_df["abstract"].fillna("")
# Merge Title and Abstract of the neighbour patents
for i in range(CFG.num_neighbors):
    test_df = test_df.merge(
        patent_df,
        left_on=target_cols[i],
        right_on="publication_number",
        how="left",
        suffixes=("", f"_{i}"),
    )
    # Fill NaN values
    test_df[f"title_{i}"] = test_df[f"title_{i}"].fillna("")
    test_df[f"abstract_{i}"] = test_df[f"abstract_{i}"].fillna("")
    # Drop the extra publication_number column introduced by the merge
    test_df = test_df.drop(columns=[f"publication_number_{i}"])
# Reset index order as it will be used later for iteration
test_df = test_df.reset_index(drop=True)
# Clean up memory
del meta_df, patent_df
gc.collect()
test_df.head()
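# Optional sanity check (assumption, not part of the original pipeline): the left joins above
# should keep exactly one row per test patent.
assert test_df["publication_number"].is_unique
print("Rows after merging:", len(test_df))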
prompt_template = "Task:\nAnalyze and compare the given two patent abstracts and titles, and identify the common or similar query keywords that should yield these two patents when searched in the United States Patent and Trademark Office (USPTO) database.\n\nInstructions:\n1. Carefully read and understand the provided 'Patent 1' and 'Patent 2' titles and abstracts below.\n2. Identify the key terms, concepts, and components that are either common or similar in both patent titles and abstracts.\n3. In the 'Keywords' section below, write the common or similar keywords, separating each keyword with a semicolon (';') and a space (' '). Here is an example response, 'keyword1; keyword2; keyword3_1 keyword3_2'.\n4. Do not add any additional narratives or text before or after the keywords.\n\nPatent 1:\n* Title: {title_a}\n* Abstract: {abstract_a}\n\nPatent 2:\n* Title: {title_b}\n* Abstract: {abstract_b}\n\nKeywords:"
chat_template = f"<start_of_turn>user\n{prompt_template}<end_of_turn>\n<start_of_turn>model\n"
print(prompt_template)
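# Note (sanity check, not in the original flow): chat_template is an f-string, so only
# {prompt_template} is substituted above; the inner {title_a}/{abstract_a}/{title_b}/{abstract_b}
# placeholders remain literal text and are filled in later by str.format() in create_prompt.
assert "{title_a}" in chat_template and "{abstract_b}" in chat_template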
def create_prompt(row, neighbor_idx):
    prompt = chat_template.format(
        title_a=row["title"], abstract_a=row["abstract"],
        title_b=row[f"title_{neighbor_idx}"], abstract_b=row[f"abstract_{neighbor_idx}"],
    )
    return prompt
prompt_sample = create_prompt(test_df.iloc[2], 0)
print(prompt_sample)
# Declare the model
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(CFG.preset)
# Keep the input sequence length small to limit memory use and latency
gemma_lm.preprocessor.sequence_length = CFG.input_length
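# Optional (an assumption, not part of the original setup): pin the decoding strategy to greedy
# sampling so repeated runs yield the same keywords; KerasNLP causal LMs accept a sampler name
# via compile().
gemma_lm.compile(sampler="greedy")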
def generate_keyword(row, neighbor_idx):
    # Return an empty keyword if every title/abstract field of the pair is an empty string
    fields = [
        "title",
        "abstract",
        f"title_{neighbor_idx}",
        f"abstract_{neighbor_idx}",
    ]
    if all(row[field] == "" for field in fields):
        return [""]
    # Create prompt
    prompt = create_prompt(row, neighbor_idx)
    try:
        # Generate output from model
        output = gemma_lm.generate(prompt, max_length=CFG.output_length)
        # Extract keywords from model output
        keyword = decode_output(output, prompt)
    except Exception:
        keyword = [""]
    return keyword
def decode_output(output, prompt):
    # Remove the input prompt that the model echoes back in its output
    answer = output.replace(prompt, "").strip()
    # Edge case: if the output was truncated at output_length, the prompt sections may leak through
    if "Title:" in answer and "Abstract:" in answer:
        return [""]
    # Filter out possible unwanted output text
    for x in ["Keywords:", "solution:", "Solution", "**", "\n\n", "[", "]"]:
        answer = answer.replace(x, "").strip()
    # Split into a list of keywords on the first matching delimiter
    for sep in [";\n-", ",\n-", "\n-", ";\n", ";\n*", ",\n*", ",\n", "\n*", ",", ";"]:
        if sep in answer:
            answer = answer.strip(sep).strip().split(sep + " ")
            break
    # If no delimiter matched, keep the whole answer as a single keyword
    if isinstance(answer, str):
        answer = [answer]
    # Final filtering: remove '*' and '.', and keep only keywords shorter than 40 characters
    keywords = [x.replace("*", "").replace(".", "") for x in set(answer) if len(x) < 40]
    # If no keywords were found, fall back to a single empty string
    if not len(keywords):
        keywords = [""]
    return keywords
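# Quick check of the parsing on a made-up model output (illustrative text, not a real Gemma
# response); keyword order may vary because a set is used internally.
_demo_prompt = "dummy prompt"
_demo_output = _demo_prompt + "Keywords: machine learning; neural network; device"
print(decode_output(_demo_output, _demo_prompt))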
# Keywords for all patents
keywords_all = []
for _, row in tqdm(test_df.iterrows(), total=test_df.shape[0]):
    # Keywords for one patent
    keywords = []
    # Iteratively create keywords for each (patent, neighbour) pair
    for j in range(CFG.num_neighbors):
        keywords += generate_keyword(row, neighbor_idx=j)
    # Remove duplicate keywords
    keywords = list(set(keywords))
    # Collect the keywords for this patent
    keywords_all.append(keywords)
_ = [print(f"Keywords {i}: {q}", end="\n\n") for i, q in enumerate(keywords_all[:3])]
BRS_STOPWORDS = ['an', 'are', 'by', 'for', 'if', 'into', 'is', 'no', 'not', 'of', 'on', 'such',
'that', 'the', 'their', 'then', 'there', 'these', 'they', 'this', 'to', 'was', 'will', 'and', 'or']
NUMBER_REGEX = re.compile(r'^(\d+|\d{1,3}(,\d{3})*)(\.\d+)?$')
class NumberFilter(whoosh.analysis.Filter):
    def __call__(self, tokens):
        for t in tokens:
            if not NUMBER_REGEX.match(t.text):
                yield t
custom_analyzer = whoosh.analysis.StandardAnalyzer(stoplist=BRS_STOPWORDS) | NumberFilter()
it = custom_analyzer("device, 1.023, machine, that, learning, there")
[token.text for token in it]
query_validator = whoosh_utils.QueryValidator()
def validate_query(query):
    # Fall back to a default query if the input is empty or not a string
    if not isinstance(query, str) or not len(query):
        query = "ti:device"
    try:
        query_validator.validate_query(query)
    except Exception:
        # Invalid query syntax: fall back to the default query
        query = "ti:device"
    return query
validate_query("device OR machine") # query is valid
validate_query("(device OR machine") # query is invalid due to missing ')' thus returns default query
queries = []
for i, row in tqdm(test_df.iterrows(), total=len(test_df)):
    # Create query from cpc_codes (empty string if the patent has no CPC codes)
    cpc = row["cpc_codes"]
    query_cpc = f"cpc:({' OR '.join(cpc[:15])})" if len(cpc) else ""
    # Default to the CPC-only query in case no keyword query can be built
    query = query_cpc
    try:
        # Analyze the keywords
        keywords_str = ", ".join(keywords_all[i])
        tokens = list(set([token.text for token in custom_analyzer(keywords_str)]))
        # Drop keywords until the combined query fits the 50-token limit
        while len(tokens):
            # Create query from keywords
            query_keywords = f"({' OR '.join(tokens)})"
            # Merge the keyword and cpc_codes queries
            query_check = f"detd:{query_keywords}" + (f" AND {query_cpc}" if len(query_cpc) else "")
            # Accept the query once its token count is within the limit
            if whoosh_utils.count_query_tokens(query_check) < 50:
                query = query_check
                break
            # Otherwise drop one keyword and try again
            tokens.pop()
    except Exception:
        query = query_cpc
    # Final query validation
    query = validate_query(query)
    queries.append(query)
_ = [print(f"Query {i}: {q}", end="\n\n") for i, q in enumerate(queries[:3])]
test_df["query"] = queries
pred_df = test_df[["publication_number", "query"]]
pred_df.to_csv("submission.csv", index=False)
pred_df.head()